import os
import sys
from glob import glob

# Root directory expected to hold one "rank_<i>" subdirectory per rank.
ckpt_dir = "checkpoint_filtered"

# Number of ranks to verify: first CLI argument, defaulting to 4096.
rank_size = int(sys.argv[1]) if len(sys.argv) > 1 else 4096
print("rank_size: ", rank_size)
if rank_size <= 0:
    # A non-positive count would check nothing and falsely report success.
    sys.exit(f"rank_size must be a positive integer, got {rank_size}")
rank_dirs = [f"rank_{i}" for i in range(rank_size)]

# Every rank directory must exist and contain exactly one "*.ckpt" file.
# All problems are reported before the overall verdict is printed.
check_succeed = True
for rank_dir in rank_dirs:
    rank_path = os.path.join(ckpt_dir, rank_dir)
    if not os.path.exists(rank_path):
        print(rank_path + " not found!")
        check_succeed = False
        continue

    # glob() always returns a list, so len() alone distinguishes 0 / 1 / many.
    ckpt = glob(os.path.join(rank_path, "*.ckpt"))
    if len(ckpt) == 1:
        continue
    if len(ckpt) > 1:
        print(rank_path + " has more than 1 ckpt!")
    else:
        print(rank_path + " no ckpt!")
    check_succeed = False

if check_succeed:
    print("check succeed!")
else:
    print("check failed!")
    # Non-zero exit status so shell scripts / CI jobs can detect the failure.
    sys.exit(1)
